{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8f59f4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'S-1.82=-89.17+9.82+77.53E'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "class Tokenizer:\n",
    "    \"\"\"Character-level tokenizer that also synthesises arithmetic samples.\n",
    "\n",
    "    The vocabulary is the concatenation, in insertion order, of four character\n",
    "    groups: marks (``PSEU``), digits, arithmetic symbols (``+-*/``) and the\n",
    "    characters ``.:=_``. ``decoder`` is the id -> character list and\n",
    "    ``encoder`` is the inverse character -> id mapping.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.vocab = {\n",
    "            'mark': list('PSEU'),\n",
    "            'number': list('0123456789'),\n",
    "            'symbol': list('+-*/'),\n",
    "            'other': list('.:=_')\n",
    "        }\n",
    "\n",
    "        # Flatten the groups in insertion order; a character's id is its index.\n",
    "        self.decoder = [ch for group in self.vocab.values() for ch in group]\n",
    "        self.encoder = {ch: idx for idx, ch in enumerate(self.decoder)}\n",
    "\n",
    "    def get_data(self, third_number):\n",
    "        \"\"\"Build one random sample and return it as a list of token ids.\n",
    "\n",
    "        Two numbers drawn with ``random.uniform(-100, 100)`` and formatted to\n",
    "        two decimals are joined by a random operator from ``+-*/``; when\n",
    "        ``third_number`` is truthy a third number is appended with ``+``.\n",
    "        The expression is evaluated and the sample is laid out as\n",
    "        ``S<value>=<expression>E``, i.e. the value comes first and the\n",
    "        expression that produces it comes second.\n",
    "\n",
    "        Args:\n",
    "            third_number: whether to append a third, ``+``-joined operand.\n",
    "\n",
    "        Returns:\n",
    "            list[int]: the encoder id of every character of the sample.\n",
    "        \"\"\"\n",
    "        expression = ''\n",
    "        for _ in range(2):\n",
    "            expression += '%.2f' % random.uniform(-100, 100)\n",
    "            expression += random.choice(self.vocab['symbol'])\n",
    "\n",
    "        # Drop the dangling operator left behind by the second iteration.\n",
    "        expression = expression[:-1]\n",
    "        if third_number:\n",
    "            expression += '+%.2f' % random.uniform(-100, 100)\n",
    "\n",
    "        try:\n",
    "            # eval() only ever sees the expression assembled above from\n",
    "            # formatted floats and the operator vocabulary, never outside input.\n",
    "            value = '%.2f' % eval(expression)\n",
    "        except ZeroDivisionError:\n",
    "            # An operand that formats to 0.00 can end up as a divisor; such\n",
    "            # samples are labelled 0.00. Other exceptions are not swallowed.\n",
    "            value = '0.00'\n",
    "\n",
    "        # Swap the question/answer direction: value first, expression second.\n",
    "        text = 'S' + value + '=' + expression + 'E'\n",
    "        return [self.encoder[ch] for ch in text]\n",
    "\n",
    "    def decode(self, token):\n",
    "        \"\"\"Join the characters for the ids in ``token`` into one string.\"\"\"\n",
    "        return ''.join([self.decoder[idx] for idx in token])\n",
    "\n",
    "\n",
    "# Module-level tokenizer instance; later cells read it as a global.\n",
    "tokenizer = Tokenizer()\n",
    "\n",
    "# Show one decoded three-operand sample as the cell output.\n",
    "tokenizer.decode(tokenizer.get_data(third_number=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b7eb136f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d65b850f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelGEN(torch.nn.Module):\n",
    "    \"\"\"Small Llama backbone followed by a bias-free linear vocabulary head.\n",
    "\n",
    "    Construction reads two notebook-level globals: ``tokenizer`` (its decoder\n",
    "    length sets the vocabulary size) and ``device`` (where the module is\n",
    "    moved). The instance is left in training mode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        from transformers import LlamaConfig, LlamaModel\n",
    "\n",
    "        self.config = LlamaConfig(hidden_size=64,\n",
    "                                  intermediate_size=64,\n",
    "                                  max_position_embeddings=128,\n",
    "                                  num_attention_heads=4,\n",
    "                                  num_hidden_layers=4,\n",
    "                                  num_key_value_heads=4,\n",
    "                                  vocab_size=len(tokenizer.decoder))\n",
    "\n",
    "        self.feature = LlamaModel(self.config)\n",
    "        # Size the head from the config rather than repeating the literal 64,\n",
    "        # so changing hidden_size above cannot leave the head mismatched.\n",
    "        self.fc_out = torch.nn.Linear(self.config.hidden_size,\n",
    "                                      self.config.vocab_size,\n",
    "                                      bias=False)\n",
    "\n",
    "        self.to(device)\n",
    "        self.train()\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        \"\"\"Run the backbone and project its last hidden state through ``fc_out``.\n",
    "\n",
    "        Args:\n",
    "            input_ids: token ids passed to the backbone as ``input_ids``.\n",
    "            attention_mask: mask passed to the backbone as ``attention_mask``.\n",
    "\n",
    "        Returns:\n",
    "            The output of ``fc_out`` applied to the backbone's\n",
    "            ``last_hidden_state``.\n",
    "        \"\"\"\n",
    "        hidden = self.feature(input_ids=input_ids,\n",
    "                              attention_mask=attention_mask).last_hidden_state\n",
    "\n",
    "        return self.fc_out(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a77c750e",
   "metadata": {},
   "outputs": [],
   "source": [
    "generater = None\n",
    "\n",
    "\n",
    "def generate(model_gen, input_ids):\n",
    "    \"\"\"Sample continuations of ``input_ids`` using ``model_gen``'s weights.\n",
    "\n",
    "    A causal-LM wrapper is created from ``model_gen.config`` and its backbone\n",
    "    and head are replaced by ``model_gen.feature`` and ``model_gen.fc_out``.\n",
    "    The wrapper is cached in the module-level ``generater`` and rebuilt\n",
    "    whenever it does not wrap the modules of the ``model_gen`` passed in, so\n",
    "    a second model is never sampled through the first model's weights.\n",
    "\n",
    "    Args:\n",
    "        model_gen: object exposing ``config``, ``feature`` and ``fc_out``.\n",
    "        input_ids: prompt ids forwarded unchanged to the wrapper's\n",
    "            ``generate`` (presumably a batch of id sequences on ``device``;\n",
    "            confirm against the caller).\n",
    "\n",
    "    Returns:\n",
    "        Whatever the wrapper's ``generate`` returns for these settings.\n",
    "    \"\"\"\n",
    "    global generater\n",
    "    needs_rebuild = (generater is None\n",
    "                     or generater.model is not model_gen.feature\n",
    "                     or generater.lm_head is not model_gen.fc_out)\n",
    "    if needs_rebuild:\n",
    "        # Wrapper class, used only to drive generation.\n",
    "        from transformers import AutoModelForCausalLM\n",
    "        wrapper = AutoModelForCausalLM.from_config(model_gen.config)\n",
    "        wrapper.model = model_gen.feature\n",
    "        wrapper.lm_head = model_gen.fc_out\n",
    "        wrapper.to(device)\n",
    "        # Publish only once fully configured, so a failure above cannot\n",
    "        # leave a half-built wrapper cached.\n",
    "        generater = wrapper\n",
    "\n",
    "    # Sampling stops at 'E', pads with 'P' and adds at most 35 new tokens.\n",
    "    return generater.generate(input_ids=input_ids,\n",
    "                              min_length=-1,\n",
    "                              top_k=0.0,\n",
    "                              top_p=1.0,\n",
    "                              do_sample=True,\n",
    "                              pad_token_id=tokenizer.encoder['P'],\n",
    "                              max_new_tokens=35,\n",
    "                              eos_token_id=tokenizer.encoder['E'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt2]",
   "language": "python",
   "name": "conda-env-pt2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
